"""
#
# flow stability for dynamic community detection https://arxiv.org/abs/2101.06131
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


This script loads precomputed clustering results for the synthetic toy-model
temporal networks and plots the comparison between flow stability and
multilayer modularity shown in Fig. 4.

"""

import sys
import os
# Append the directory one level above this script to sys.path, so that the
# project modules imported below (TemporalNetwork, and later TemporalStability)
# can be imported whatever the current working directory is.
# Assumes those modules live in that parent directory.
PACKAGE_PARENT = '..'
# Absolute directory containing this script ('~' expanded, symlinks resolved).
SCRIPT_DIR = os.path.dirname(os.path.realpath(os.path.join(os.getcwd(), os.path.expanduser(__file__))))
sys.path.append(os.path.normpath(os.path.join(SCRIPT_DIR, PACKAGE_PARENT)))

import numpy as np
from TemporalNetwork import ContTempNetwork
import pickle
import matplotlib.pyplot as plt
    



def Round_To_n(x, n):
    """Round ``x`` keeping ``n`` digits after its leading significant digit.

    The result therefore has ``n + 1`` significant figures, e.g.
    ``Round_To_n(1234, 2) == 1230`` and ``Round_To_n(0.012345, 1) == 0.012``.
    Negative values are rounded symmetrically to positive ones.

    Parameters
    ----------
    x : float or int
        Value to round. ``0`` and non-finite values (nan, +/-inf) have no
        leading significant digit and are returned unchanged.
    n : int
        Number of digits kept after the leading significant digit.

    Returns
    -------
    The rounded value, as produced by the builtin ``round``.
    """
    if x == 0 or not np.isfinite(x):
        return x
    # Decimal exponent of the leading significant digit (3 for 1234, -2 for
    # 0.0123). It must be computed from abs(x) alone: scaling the logarithm by
    # sign(x) would flip the exponent for negative inputs.
    exponent = int(np.floor(np.log10(abs(x))))
    return round(x, -exponent + n)

# Guard: everything below is organised as ``#%%`` cells that are run one at a
# time (loading pickles, exploratory plots, final figure). Stop a plain
# top-to-bottom execution here with an explanatory message.
raise RuntimeError('This script is meant to be run cell by cell (#%% markers); '
                   'it stops here when executed as a whole.')
#%%
# Input directory (pickled networks and clustering results) and output
# directory for the figures, both relative to the working directory.
savedir = '../paper_data/synthtempnet/toy_model_comparison'
figdir = '../figures/paper_figures'
os.makedirs(savedir, exist_ok=True)
os.makedirs(figdir, exist_ok=True)

# Common prefix of every pickle file read from `savedir`.
filename = 'synthtemp_toymod_comp'

# Number of simulated networks: one pickle per simulation, suffixed with its
# two-digit index.
num_sims = 10


#%% load networks
with open(os.path.join(savedir, f'{filename}_sim_param.pickle'), 'rb') as fopen:
    sim_param = pickle.load(fopen)

nets = [ContTempNetwork.load(os.path.join(savedir, f'{filename}{i:02d}.pickle'))
        for i in range(num_sims)]


#%% clustering
# NOTE(review): `num_repeat` is not read anywhere in the visible code.
num_repeat = 50


#%% load precomputed flow-stability results
with open(os.path.join(savedir, f'{filename}_flow_stab_res.pickle'), 'rb') as fopen:
    res_flowstab = pickle.load(fopen)


#%% load precomputed multislice (multilayer modularity) results
with open(os.path.join(savedir, f'{filename}_multislice_res.pickle'), 'rb') as fopen:
    multislice_res = pickle.load(fopen)
    

#%% flowstab nmi and num clusts
# Exploratory plots vs. lambda: mean +/- std over the `num_sims` simulations,
# one figure per metric, each overlaying the forward and backward runs.
sim = 0

for _fields in (('avg_nmi_forw', 'avg_nmi_back'),
                ('num_clusts_forw', 'num_clusts_back')):
    plt.figure()
    for _field in _fields:
        # _samples[i] collects `_field` at lambda index i, one value per simulation.
        _samples = [[res_flowstab[_field][s][i] for s in range(num_sims)]
                    for i in range(len(res_flowstab['lambdas']))]
        plt.errorbar(res_flowstab['lambdas'],
                     [np.mean(vals) for vals in _samples],
                     yerr=[np.std(vals) for vals in _samples],
                     capsize=10)
    plt.xscale('log')

# Remove the loop temporaries from the module namespace.
del _fields, _field, _samples



#%% multislice nmi and num clusts
# Exploratory plots vs. the resolution parameter: mean +/- std over the
# `num_sims` simulations, one figure per metric.
sim = 0

# x-axis values: the resolution parameters stored for simulation `sim`.
_res_params = multislice_res['res_params'][sim]
for _field in ('avg_nmi_multi', 'num_clusts_multi'):
    plt.figure()
    # _samples[i] collects `_field` at resolution index i, one value per simulation.
    _samples = [[multislice_res[_field][s][i] for s in range(num_sims)]
                for i in range(len(_res_params))]
    plt.errorbar(_res_params,
                 [np.mean(vals) for vals in _samples],
                 yerr=[np.std(vals) for vals in _samples],
                 capsize=10)
    plt.xscale('log')

# Remove the loop temporaries from the module namespace.
del _res_params, _field, _samples


#%%
# Draw, for every simulation, the multislice partition found at resolution
# index 7 as a matrix of cluster labels (rows: nodes, columns: slices).
# The loop variable is no longer overridden inside the body, so each
# simulation gets its own figure.
for sim in range(num_sims):
    # The flat label vector is reshaped to (5, num_nodes) and transposed;
    # 5 matches the number of time slices used in the final figure below.
    clust_matrix = np.array(multislice_res['best_clusts_multi'][sim][7]).reshape((5,
                                                              nets[0].num_nodes)).T
    plt.figure()
    plt.imshow(clust_matrix)
    plt.title(f'sim {sim}')

# pd.to_pickle(res,os.path.join(datadir, 'multislice_results' + f'{sim:02d}' + '.pickle'))


# plt.savefig(os.path.join(figdir, filename + '_multislice_5.png'))


#%% multislice results
# How often does each multislice partition occur across simulations at a
# fixed resolution index?
from collections import Counter
from TemporalStability import Partition

res_idx = 7
# Bare expressions in this cell are only displayed when run interactively.
multislice_res['best_clusts_multi'][sim][res_idx]

# Count identical raw label vectors (exact match, label values included).
count = Counter(tuple(multislice_res['best_clusts_multi'][k][res_idx])
                for k in range(num_sims))
count.most_common(10)

# Turn each label vector (5 * num_nodes entries) into a Partition and keep
# its cluster list (presumably a list of node-index collections).
clust_list = [
    Partition(num_nodes=5 * nets[0].num_nodes,
              node_to_cluster_dict=dict(enumerate(labels))).cluster_list
    for labels in (multislice_res['best_clusts_multi'][k][res_idx]
                   for k in range(num_sims))
]

# Canonical, label-independent form: each cluster as a sorted tuple, and the
# clusters themselves sorted.
countm = Counter(tuple(sorted(tuple(sorted(cluster)) for cluster in clusters))
                 for clusters in clust_list)
countm.most_common(10)
#%% flow stab results
# Print the best forward/backward partitions of every simulation at the first
# lambda, then the most frequent partition at each lambda.

lambda_idx = 0

print('forw\n', '\n'.join(str(res_flowstab['best_clusts_forw'][k][lambda_idx])
                          for k in range(num_sims)))
print('back\n', '\n'.join(str(res_flowstab['best_clusts_back'][k][lambda_idx])
                          for k in range(num_sims)))

lambdas = res_flowstab['lambdas']
for lambda_idx in range(len(lambdas)):
    print(f'tau = {1/lambdas[lambda_idx]}')
    # Order-independent form of each partition: sorted tuple of sorted clusters.
    cforw = Counter(
        tuple(sorted(tuple(sorted(cl)) for cl in res_flowstab['best_clusts_forw'][k][lambda_idx]))
        for k in range(num_sims))
    cback = Counter(
        tuple(sorted(tuple(sorted(cl)) for cl in res_flowstab['best_clusts_back'][k][lambda_idx]))
        for k in range(num_sims))
    # most_common(1)[0] is the (partition, count) pair of the top partition.
    print(cforw.most_common(1)[0])
    print(cback.most_common(1)[0])

#%% sankey plot
# Build the flow table for a Sankey diagram linking a "source" partition to a
# "target" partition of the 8 toy-model nodes (0..7).

import pandas as pd
# small scale

# NOTE(review): the two assignments below are immediately overwritten by the
# next pair; they look like a manual toggle between two configurations (only
# the last pair is used). Confirm which pair belongs to which Sankey figure.
source_comms = [{0,1},{2,3},{4,5},{6,7}]
target_comms = [{0,1},{2,3},{4,5},{6,7}]

source_comms = [{0,1,2,3},{4,5,6,7}]
target_comms = [{0,1,6,7},{2,3,4,5}]

# for fast

# Each node is its own class, so every flow row carries exactly one node.
class_dict = {i : {i} for i in range(8)}

# One row per (class, source community, target community) with a non-empty
# intersection; `value` is the number of nodes in that intersection.
flows = []
for clas, clas_set in class_dict.items():
    for s, comm_s in enumerate(source_comms):
        for t, comm_t in enumerate(target_comms):
            val = len(clas_set.intersection(comm_s).intersection(comm_t))
            if val > 0:
                flows.append({'source': int(s), 'target': int(t), 'type': int(clas), 'value': int(val)})

df_flows = pd.DataFrame.from_dict(flows)

# start from 1
# Shift the 0-based community/class indices to 1-based labels for display.
df_flows.source += 1
df_flows.target += 1
df_flows.type += 1

import floweaver as flo

# Colours assigned, in order, to the flow classes (see `palette` below).
color_list =["#4ba706",
"#a2007e",
"#806dcb",
"#5eb275",
"#ca3b01",
"#01a4d6",
"#b77600",
"#a39643",
"#cc6ea9",
"#1e5e39",
"#cb5b5a"]

# Quick visual preview of `color_list`: one square swatch per colour.
plt.figure()
plt.axes()
for i,col in enumerate(color_list):
    rectangle = plt.Rectangle((i*1,0), 0.5, 0.5, fc=color_list[i])
    plt.gca().add_patch(rectangle)

plt.xlim((0,11))    
# Distinct source / target community labels present in the flow table.
sources_list = df_flows.source.unique().tolist()

sources_list = np.array(sources_list)

targets_list = df_flows.target.unique().tolist()

targets_list = np.array(targets_list)

class_list = sorted(df_flows.type.unique().tolist())


# Two process groups: the source communities on the left ("forward") and the
# target communities on the right ("backward").
nodes = {
    'forward': flo.ProcessGroup(sources_list.tolist()),
    'backward': flo.ProcessGroup(targets_list.tolist()),
}

ordering = [
    ['forward'],       # put "forward" on the left...
    ['backward'],   # ... and "backward" on the right.
]

bundles = [
    flo.Bundle('forward', 'backward'),
]
# The first argument is the dimension name -- for now we're using
# "process" to group by process ids. The second argument is a list
# of groups.
source_part = flo.Partition.Simple('process', sources_list.tolist())

# This is another partition.
target_part = flo.Partition.Simple('process', targets_list.tolist())

# Update the ProcessGroup nodes to use the partitions
nodes['forward'].partition = source_part
nodes['backward'].partition = target_part
nodes['forward'].title = 'Forward'
nodes['backward'].title = 'Backward'

# Another partition -- but this time the dimension is the "type"
# column of the flows table
class_flow_part = flo.Partition.Simple('type', class_list)

# Set the colours for the labels in the partition.
# zip() stops at the shorter list, so at most len(color_list) classes get a
# colour. NOTE(review): keys are str(class); presumably floweaver matches flow
# types by their string label -- confirm.
palette = {str(cl) : col for cl, col in zip(class_list, color_list)}

# New SDD with the flow_partition set
sdd = flo.SankeyDefinition(nodes, bundles, ordering,
                       flow_partition=class_flow_part)

# Weave the definition with the flow table into the Sankey data `w`.
w = flo.weave(sdd, df_flows, palette=palette)

#%% save flow and render svg
import json

# Output paths for the Sankey JSON and SVG files.
# NOTE(review): the "slow" paths are immediately overwritten by the "fast"
# ones; this looks like a manual toggle that presumably has to match the
# community choice made in the "sankey plot" cell -- confirm before saving.
json_file = os.path.join('../figures/paper_figures/','ortho_flow_slow.json')
svg_file = os.path.join('../figures/paper_figures/','ortho_flow_slow.svg')

json_file = os.path.join('../figures/paper_figures/','ortho_flow_fast.json')
svg_file = os.path.join('../figures/paper_figures/','ortho_flow_fast.svg')

class NpEncoder(json.JSONEncoder):
    """JSON encoder that also accepts NumPy scalars and arrays.

    NumPy integers, floats and booleans are converted to the equivalent
    Python ``int``/``float``/``bool``, and ``ndarray`` objects to (nested)
    lists. Any other unsupported object is delegated to
    ``json.JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is not a subclass of Python bool, so the stock encoder
        # rejects it.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
    
    
# Serialise the woven Sankey data `w`; NpEncoder handles any NumPy values.
with open(json_file, 'w') as fopen:
    json.dump(w.to_json(), fopen, cls=NpEncoder)

# Render the JSON to SVG with the external `svg-sankey` command-line tool.
# os.system returns the command's exit status. Report a failure instead of
# ignoring it: the shell redirection creates/truncates `svg_file` even when
# the command fails.
sankey_status = os.system(f'svg-sankey --size 250,200 --font-size 10 --margins 2,2 {json_file}  > {svg_file}')
if sankey_status != 0:
    print(f'WARNING: svg-sankey failed (exit status {sankey_status}); '
          f'{svg_file} is likely empty or incomplete.')


#%% final figure for paper
import matplotlib as mpl
# NOTE(review): 'alex_paper' is a custom style sheet; it must be available to
# matplotlib (e.g. in the user's stylelib directory) or this line raises.
plt.style.use('alex_paper')

# Global overrides applied on top of the style sheet.
mpl.rcParams['lines.linewidth'] = 1.1
mpl.rcParams['font.size'] = 12
        
from matplotlib.gridspec import GridSpec

from cycler import cycler

# Shared errorbar styling used by the line-plot panels.
markersize=4
marker='o-'
# Simulation index whose resolution parameters are used as the x-axis of the
# multislice panels.
sim=0

# NOTE(review): `legend_fontsize` is not used in the visible code.
legend_fontsize = 10

# Line colours for all axes created below; also reused as the colormap of the
# partition heatmaps.
colors = [ "#359455","#613D94","#D0730F", '#D13204', ]
default_cycler = (cycler(color=colors))
plt.rc('axes', prop_cycle=default_cycler)
                  
capsize=5
        
# 9x9 grid: the first three columns hold the four line plots (flow stability
# on top, multislice below, with an empty row 4 in between); columns 3-5 and
# 6-8 hold the partition panels (flow stability rows 0-3, multislice rows 5-8).
fig = plt.figure(figsize=(10,6))
gs = GridSpec(9, 9, figure=fig)

gs.update(wspace=1.0, hspace=0.2,
          left=0.1, right=0.95,
          top=0.95, bottom=0.1) # set the spacing between axes. 

# Flow stability: number of clusters (top) and NVI (below, shares x).
ax00_ncf = fig.add_subplot(gs[:2,0:3])
ax01_nvif = fig.add_subplot(gs[2:4,0:3],sharex=ax00_ncf)
# Multislice: number of clusters and NVI, sharing y with the flow-stability
# panels so that both methods are on the same vertical scale.
ax10_ncm = fig.add_subplot(gs[5:7,0:3],sharey=ax00_ncf)
ax11_nvim = fig.add_subplot(gs[7:9,0:3],sharey=ax01_nvif,sharex=ax10_ncm)
# Partition panels: multislice heatmaps (bottom) and flow-stability panels
# (top; nothing is drawn into these in this file).
ax12_patm = fig.add_subplot(gs[5:9,3:6])
ax12_patm2 = fig.add_subplot(gs[5:9,6:9])
ax02_patf = fig.add_subplot(gs[0:4,3:6])
ax02_patf2 = fig.add_subplot(gs[0:4,6:9])

#num clust flowstab
# Panels A (number of clusters) and B (normalised variation of information)
# for flow stability, against tau_w = 1 / lambda, as mean +/- std over the
# simulations.
taus = [1/l for l in res_flowstab['lambdas']]


def _flowstab_mean_std(field):
    """Per-lambda mean and std, over simulations, of res_flowstab[field]."""
    per_lambda = [[res_flowstab[field][s][i] for s in range(num_sims)]
                  for i in range(len(res_flowstab['lambdas']))]
    return ([np.mean(vals) for vals in per_lambda],
            [np.std(vals) for vals in per_lambda])


_errbar_style = dict(capsize=capsize, fmt=marker, ms=markersize)

for _field, _label in (('num_clusts_forw', 'forw.'), ('num_clusts_back', 'back.')):
    _mean, _std = _flowstab_mean_std(_field)
    ax00_ncf.errorbar(taus, _mean, yerr=_std, label=_label, **_errbar_style)
ax00_ncf.set_xscale('log')

# ax00_ncf.legend()
# Panel A shares its x-axis with panel B below it, so hide A's tick labels.
plt.setp(ax00_ncf.get_xticklabels(), visible=False)

#nmi flow stab
for _field, _label in (('avg_nmi_forw', 'forw.'), ('avg_nmi_back', 'back.')):
    _mean, _std = _flowstab_mean_std(_field)
    ax01_nvif.errorbar(taus, _mean, yerr=_std, label=_label, **_errbar_style)
ax01_nvif.set_xscale('log')
ax01_nvif.legend()

# Remove the helper and temporaries from the module namespace.
del _flowstab_mean_std, _errbar_style, _field, _label, _mean, _std

#num clust multi
# Panels C (number of clusters) and D (normalised variation of information)
# for multilayer modularity against the resolution parameter, as mean +/- std
# over the simulations.


def _multislice_mean_std(field, res_params):
    """Per-resolution mean and std, over simulations, of multislice_res[field]."""
    per_res = [[multislice_res[field][s][i] for s in range(num_sims)]
               for i in range(len(res_params))]
    return ([np.mean(vals) for vals in per_res],
            [np.std(vals) for vals in per_res])


_mean, _std = _multislice_mean_std('num_clusts_multi', multislice_res['res_params'][sim])
ax10_ncm.errorbar(multislice_res['res_params'][sim], _mean, yerr=_std,
                  capsize=capsize, fmt=marker, ms=markersize)
ax10_ncm.set_xscale('log')
# Panel C shares its x-axis with panel D below it, so hide C's tick labels.
plt.setp(ax10_ncm.get_xticklabels(), visible=False)

#nvi multi
_mean, _std = _multislice_mean_std('avg_nmi_multi', multislice_res['res_params'][0])
ax11_nvim.errorbar(multislice_res['res_params'][0], _mean, yerr=_std,
                   capsize=capsize, fmt=marker, ms=markersize)
ax11_nvim.set_xscale('log')

del _multislice_mean_std, _mean, _std

# Only the bottom panel of each x-shared pair gets an x label.
ax01_nvif.set_xlabel(r'$\tau_w$')
ax11_nvim.set_xlabel(r'resolution')

for _ax, _ylabel in ((ax00_ncf, 'Num. clusters'),
                     (ax10_ncm, 'Num. clusters'),
                     (ax01_nvif, 'Norm. Var. Inf'),
                     (ax11_nvim, 'Norm. Var. Inf')):
    _ax.set_ylabel(_ylabel)
del _ax, _ylabel

# Multislice partitions shown in panels F and H, each as a matrix of cluster
# labels with one row per node and one column per time slice (5 slices).
# Author's notes: clust_matrix is the most frequent partition at the minima of
# NVI; clust_matrix2 is the partition at the minima of NVI with the smallest
# average NVI.
clust_matrix = np.array(multislice_res['best_clusts_multi'][1][6]).reshape((5,
                                                          nets[0].num_nodes)).T
clust_matrix2 = np.array(multislice_res['best_clusts_multi'][8][7]).reshape((5,
                                                          nets[0].num_nodes)).T

# NOTE(review): this palette is not used in the visible code; the heatmaps
# below use `cmap`, which is built from `colors`.
color_list = ["#4ba706",
              "#a2007e",
              "#ca3b01",
              "#01a4d6",
              "#b77600",
              "#a39643",
              "#cc6ea9",
              "#1e5e39",
              "#cb5b5a"]

cmap = mpl.colors.ListedColormap(colors)

# (axes, label matrix, resolution index for the title, x label)
for _ax, _mat, _ridx, _xlabel in ((ax12_patm, clust_matrix, 6, ''),
                                  (ax12_patm2, clust_matrix2, 7, 'time slices')):
    _ax.imshow(_mat, aspect='auto', cmap=cmap)
    _ax.set_ylabel('nodes')
    _ax.set_xlabel(_xlabel)
    _ax.set_title(f'resolution = {multislice_res["res_params"][0][_ridx]:.2f}')
    # Show 1-based node and slice numbers on the ticks.
    _ax.set_yticks(range(0,8))
    _ax.set_yticklabels(range(1,9))
    _ax.set_xticks(range(0,5))
    _ax.set_xticklabels(range(1,6))
del _ax, _mat, _ridx, _xlabel


# Bold panel letters in axes coordinates, rendered with LaTeX. The matplotlib
# Text property is `usetex` (lower case); the misspelled `useTex` keyword is
# rejected as an unknown property. Panels E and G are not drawn in this file.
for _ax, _letter, _y in ((ax00_ncf, 'A', 1.1),
                         (ax01_nvif, 'B', 1.05),
                         (ax02_patf, 'E', 1.05),
                         (ax10_ncm, 'C', 1.05),
                         (ax11_nvim, 'D', 1.05),
                         (ax12_patm, 'F', 1.05),
                         (ax02_patf2, 'G', 1.05),
                         (ax12_patm2, 'H', 1.05)):
    _ax.text(-0.1, _y, r'\textbf{' + _letter + '}', transform=_ax.transAxes, usetex=True)
del _ax, _letter, _y
#%% savefig

plt.savefig(os.path.join(figdir,'ortho_cont_compa.pdf'))
plt.savefig(os.path.join(figdir,'ortho_cont_compa.png'), dpi=600)
